"""
Artifact registry for centralized artifact storage and querying.

This module provides the ArtifactRegistry class that:
- Stores all artifacts for an execution
- Provides query API (by ID, task, type, etc.)
- Handles deduplication (by storage_path)
- Thread-safe using asyncio.Lock
- Supports cleanup of artifacts

All artifacts are execution-scoped (live for duration of execution).
"""

import asyncio
from typing import Dict, List, Optional, Set
from uuid import UUID

from loguru import logger

from roma_dspy.types import Artifact, ArtifactMetadata, ArtifactType, MediaType


class ArtifactRegistry:
    """
    Centralized registry for artifact storage and querying.

    Responsibilities:
    - Store artifacts with deduplication
    - Query artifacts by various criteria
    - Track artifact lineage
    - Provide thread-safe access

    Deduplication Strategy:
    - Key: storage_path (absolute path)
    - If same path registered twice, merge metadata (keep newer)

    Thread Safety:
    - Uses asyncio.Lock for all mutations
    - Read-only queries are lock-free (snapshot pattern)
    """

    def __init__(self):
        """Initialize empty artifact registry."""
        # Primary storage: artifact_id -> Artifact
        self._artifacts: Dict[UUID, Artifact] = {}

        # Deduplication index: storage_path -> artifact_id
        self._path_index: Dict[str, UUID] = {}

        # Query indexes
        self._task_index: Dict[str, Set[UUID]] = {}  # task_id -> {artifact_ids}
        self._type_index: Dict[
            ArtifactType, Set[UUID]
        ] = {}  # artifact_type -> {artifact_ids}
        self._media_index: Dict[
            MediaType, Set[UUID]
        ] = {}  # media_type -> {artifact_ids}

        # Thread safety
        self._lock = asyncio.Lock()

    async def register(self, artifact: Artifact) -> Artifact:
        """
        Register artifact with deduplication.

        If an artifact with the same storage_path already exists:
        - Keep the newer artifact
        - Merge metadata (preserve custom fields from both)
        - Update all indexes

        Args:
            artifact: Artifact to register

        Returns:
            Registered artifact (may be deduplicated with existing)
        """
        async with self._lock:
            # Check for duplicate by storage_path
            existing_id = self._path_index.get(artifact.storage_path)

            if existing_id:
                # Deduplication: merge with existing
                existing = self._artifacts[existing_id]
                merged = self._merge_artifacts(existing, artifact)

                # Update primary storage
                self._artifacts[existing_id] = merged

                # Indexes already point to existing_id, no update needed

                logger.debug(
                    "Deduplicated artifact",
                    artifact_id=str(existing_id),
                    path=artifact.storage_path,
                )

                return merged
            else:
                # New artifact: add to all indexes
                artifact_id = artifact.artifact_id

                # Primary storage
                self._artifacts[artifact_id] = artifact

                # Deduplication index
                self._path_index[artifact.storage_path] = artifact_id

                # Query indexes
                if artifact.created_by_task not in self._task_index:
                    self._task_index[artifact.created_by_task] = set()
                self._task_index[artifact.created_by_task].add(artifact_id)

                if artifact.artifact_type not in self._type_index:
                    self._type_index[artifact.artifact_type] = set()
                self._type_index[artifact.artifact_type].add(artifact_id)

                if artifact.media_type not in self._media_index:
                    self._media_index[artifact.media_type] = set()
                self._media_index[artifact.media_type].add(artifact_id)

                logger.debug(
                    "Registered artifact",
                    artifact_id=str(artifact_id),
                    name=artifact.name,
                    artifact_type=artifact.artifact_type,
                    task=artifact.created_by_task,
                )

                return artifact

    async def register_batch(self, artifacts: List[Artifact]) -> List[Artifact]:
        """
        Register multiple artifacts in a single transaction.

        More efficient than calling register() in a loop because:
        - Acquires lock once (not N times)
        - Atomic operation (all artifacts registered together)
        - Better performance for bulk operations (2-5x faster)

        Args:
            artifacts: List of artifacts to register

        Returns:
            List of registered artifacts (deduplicated, same order as input)
        """
        if not artifacts:
            return []

        async with self._lock:
            registered = []
            newly_registered_count = 0
            deduplicated_count = 0

            for artifact in artifacts:
                # Check for duplicate by storage_path
                existing_id = self._path_index.get(artifact.storage_path)

                if existing_id:
                    # Deduplication: merge with existing
                    existing = self._artifacts[existing_id]
                    merged = self._merge_artifacts(existing, artifact)

                    # Update primary storage
                    self._artifacts[existing_id] = merged
                    registered.append(merged)
                    deduplicated_count += 1

                    logger.debug(
                        "Deduplicated artifact in batch",
                        artifact_id=str(existing_id),
                        path=artifact.storage_path,
                    )
                else:
                    # New artifact: add to all indexes
                    artifact_id = artifact.artifact_id

                    # Primary storage
                    self._artifacts[artifact_id] = artifact

                    # Deduplication index
                    self._path_index[artifact.storage_path] = artifact_id

                    # Query indexes
                    if artifact.created_by_task not in self._task_index:
                        self._task_index[artifact.created_by_task] = set()
                    self._task_index[artifact.created_by_task].add(artifact_id)

                    if artifact.artifact_type not in self._type_index:
                        self._type_index[artifact.artifact_type] = set()
                    self._type_index[artifact.artifact_type].add(artifact_id)

                    if artifact.media_type not in self._media_index:
                        self._media_index[artifact.media_type] = set()
                    self._media_index[artifact.media_type].add(artifact_id)

                    registered.append(artifact)
                    newly_registered_count += 1

            logger.info(
                f"Batch registered {len(registered)} artifact(s)",
                newly_registered=newly_registered_count,
                deduplicated=deduplicated_count,
                total_requested=len(artifacts),
            )

            return registered

    def _merge_artifacts(self, existing: Artifact, new: Artifact) -> Artifact:
        """
        Merge two artifacts (deduplication strategy).

        Strategy:
        - Keep existing artifact_id (for deduplication)
        - Use newer artifact's fields
        - Merge custom metadata (preserve both)
        - Merge derived_from (union)

        Args:
            existing: Existing artifact
            new: New artifact to merge

        Returns:
            Merged artifact with existing's ID
        """
        # Determine which is newer
        if new.created_at > existing.created_at:
            source = new
            other = existing
        else:
            source = existing
            other = new

        # Merge custom metadata
        merged_custom = {**other.metadata.custom, **source.metadata.custom}

        # Merge lineage (union)
        merged_lineage = list(set(source.derived_from + other.derived_from))

        # Create new merged artifact with existing's ID
        merged = Artifact(
            artifact_id=existing.artifact_id,  # Keep existing ID!
            name=source.name,
            artifact_type=source.artifact_type,
            media_type=source.media_type,
            storage_path=source.storage_path,
            created_by_task=source.created_by_task,
            created_by_module=source.created_by_module,
            created_at=source.created_at,
            metadata=ArtifactMetadata(
                description=source.metadata.description,
                mime_type=source.metadata.mime_type,
                size_bytes=source.metadata.size_bytes,
                row_count=source.metadata.row_count,
                column_count=source.metadata.column_count,
                data_schema=source.metadata.data_schema,
                preview=source.metadata.preview,
                usage_hints=source.metadata.usage_hints,
                custom=merged_custom,
            ),
            derived_from=merged_lineage,
        )

        return merged

    async def get_by_id(self, artifact_id: UUID) -> Optional[Artifact]:
        """
        Get artifact by ID.

        Args:
            artifact_id: Artifact UUID

        Returns:
            Artifact if found, None otherwise
        """
        return self._artifacts.get(artifact_id)

    async def get_by_path(self, storage_path: str) -> Optional[Artifact]:
        """
        Get artifact by storage path.

        Args:
            storage_path: Absolute path to artifact file

        Returns:
            Artifact if found, None otherwise
        """
        artifact_id = self._path_index.get(storage_path)
        if artifact_id:
            return self._artifacts.get(artifact_id)
        return None

    async def get_by_task(self, task_id: str) -> List[Artifact]:
        """
        Get all artifacts created by a specific task.

        Args:
            task_id: Task ID

        Returns:
            List of artifacts (may be empty)
        """
        artifact_ids = self._task_index.get(task_id, set())
        return [self._artifacts[aid] for aid in artifact_ids]

    async def get_by_type(self, artifact_type: ArtifactType) -> List[Artifact]:
        """
        Get all artifacts of a specific type.

        Args:
            artifact_type: Artifact type enum

        Returns:
            List of artifacts (may be empty)
        """
        artifact_ids = self._type_index.get(artifact_type, set())
        return [self._artifacts[aid] for aid in artifact_ids]

    async def get_by_media(self, media_type: MediaType) -> List[Artifact]:
        """
        Get all artifacts of a specific media type.

        Args:
            media_type: Media type enum

        Returns:
            List of artifacts (may be empty)
        """
        artifact_ids = self._media_index.get(media_type, set())
        return [self._artifacts[aid] for aid in artifact_ids]

    async def get_all(self) -> List[Artifact]:
        """
        Get all registered artifacts.

        Returns:
            List of all artifacts
        """
        return list(self._artifacts.values())

    async def get_lineage(self, artifact_id: UUID) -> List[Artifact]:
        """
        Get artifact lineage (ancestors).

        Args:
            artifact_id: Artifact UUID

        Returns:
            List of ancestor artifacts (may be empty)
        """
        artifact = self._artifacts.get(artifact_id)
        if not artifact:
            return []

        ancestors = []
        for parent_id in artifact.derived_from:
            parent = self._artifacts.get(parent_id)
            if parent:
                ancestors.append(parent)

        return ancestors

    async def get_descendants(self, artifact_id: UUID) -> List[Artifact]:
        """
        Get artifact descendants (artifacts derived from this one).

        Args:
            artifact_id: Artifact UUID

        Returns:
            List of descendant artifacts (may be empty)
        """
        descendants = []
        for artifact in self._artifacts.values():
            if artifact_id in artifact.derived_from:
                descendants.append(artifact)

        return descendants

    async def remove(self, artifact_id: UUID) -> bool:
        """
        Remove artifact from registry.

        Args:
            artifact_id: Artifact UUID to remove

        Returns:
            True if removed, False if not found
        """
        async with self._lock:
            artifact = self._artifacts.get(artifact_id)
            if not artifact:
                return False

            # Remove from primary storage
            del self._artifacts[artifact_id]

            # Remove from deduplication index
            if artifact.storage_path in self._path_index:
                del self._path_index[artifact.storage_path]

            # Remove from query indexes
            self._task_index[artifact.created_by_task].discard(artifact_id)
            self._type_index[artifact.artifact_type].discard(artifact_id)
            self._media_index[artifact.media_type].discard(artifact_id)

            logger.debug("Removed artifact", artifact_id=str(artifact_id))

            return True

    async def clear(self) -> int:
        """
        Clear all artifacts from registry.

        Returns:
            Number of artifacts removed
        """
        async with self._lock:
            count = len(self._artifacts)

            self._artifacts.clear()
            self._path_index.clear()
            self._task_index.clear()
            self._type_index.clear()
            self._media_index.clear()

            logger.info(f"Cleared artifact registry ({count} artifacts)")

            return count

    async def get_stats(self) -> Dict[str, int]:
        """
        Get registry statistics.

        Returns:
            Dictionary with counts by type
        """
        return {
            "total_artifacts": len(self._artifacts),
            "unique_tasks": len(self._task_index),
            "by_artifact_type": {
                artifact_type.value: len(artifact_ids)
                for artifact_type, artifact_ids in self._type_index.items()
            },
            "by_media_type": {
                media_type.value: len(artifact_ids)
                for media_type, artifact_ids in self._media_index.items()
            },
        }
